import torch
import torch.nn as nn

from .eva_clip_processors import EvaClipImageTrainProcessor
from .eva_vit import Eva2LargePlusEncoder


class EvaClipVisionTower(nn.Module):
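    """Vision tower wrapping a frozen EVA-CLIP (EVA-02 Large+) image encoder."""
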
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_path = vision_tower
        self.config = VisionTowerConfig()

        if not delay_load:
            self.load_model()
        else:
            # Defer weight loading but keep the config reachable, so callers can
            # query hidden_size / num_patches before load_model() runs.
            self.cfg_only = self.config

    def load_model(self):
        if self.is_loaded:
            return

        self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
        self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
        # Freeze the encoder; the tower is used as a fixed feature extractor.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    @torch.no_grad()
    def forward(self, images):
        if isinstance(images, list):
            # Variable-sized inputs: encode each image separately, casting to the
            # tower's dtype/device for the pass and back to the input dtype after.
            image_features = []
            for image in images:
                image_feature = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                ).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(
                images.dtype
            )

        return image_features

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
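        # With the default config: (336 // 14) ** 2 = 576 patch tokens.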
        return (self.config.image_size // self.config.patch_size) ** 2


class VisionTowerConfig:
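    """Hard-coded shape config for the EVA-CLIP tower (336px input, patch size 14, 1024-dim features)."""
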
    def __init__(self):
        self.image_size = 336
        self.patch_size = 14
        self.hidden_size = 1024
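

# Usage sketch (illustrative, not verified against this repo: the checkpoint
# path is a placeholder and the processor's call pattern is an assumption):
#
#     tower = EvaClipVisionTower("/path/to/eva02_large_plus_ckpt", args=None)
#     pixels = tower.image_processor(pil_image)   # preprocess one PIL image
#     feats = tower(pixels.unsqueeze(0))          # expected shape (1, 576, 1024)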
